/*
 * Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 * DEALINGS IN THE SOFTWARE.
 */

#include <onnx/onnx_pb.h>
#include <google/protobuf/text_format.h>
#include <sstream>

namespace {

// NOTE(review): every definition in this file is `inline`, which suggests it is
// a header; helpers in an anonymous namespace get a separate copy in each
// including translation unit. Consider a named detail namespace instead.

// Returns the index of the '"' that terminates a C-escaped string literal whose
// contents start at `pos`, or std::string::npos if no terminating quote exists.
// A backslash escapes the character after it, so both `\"` (an escaped quote)
// and `\\"` (an escaped backslash followed by the real closing quote) are
// handled correctly.
inline std::string::size_type find_closing_quote(std::string const& s,
                                                 std::string::size_type pos) {
  for( std::string::size_type i = pos; i < s.size(); ++i ) {
    if( s[i] == '\\' ) {
      ++i; // Skip whatever character the backslash escapes
    } else if( s[i] == '"' ) {
      return i;
    }
  }
  return std::string::npos;
}

// Removes raw data from the text representation of an ONNX model.
// Every `raw_data: "<payload>"` whose escaped payload is longer than `max_len`
// characters has the payload replaced by "..."; shorter payloads are kept.
// Scanning stops (leaving the rest of `s` untouched) at an unterminated literal.
inline void remove_raw_data_strings(std::string& s,
                                    std::string::size_type max_len = 128) {
  const std::string key = "raw_data: \"";
  const std::string sub = "...";
  std::string::size_type beg = 0;
  while( (beg = s.find(key, beg)) != std::string::npos ) {
    beg += key.length(); // First character of the payload
    const std::string::size_type end = find_closing_quote(s, beg);
    if( end == std::string::npos ) {
      break;
    }
    if( end - beg > max_len ) { // Only remove large data strings
      s.replace(beg, end - beg, sub);
      beg += sub.length();
    } else {
      beg = end;
    }
    ++beg; // Resume the search just past the closing quote
  }
}

// Collapses runs of repeated numeric data fields (float_data, int32_data,
// int64_data, uint64_data, double_data) in the text representation of an
// ONNX model. A line counts as such a field only when, after leading
// whitespace, it begins with the field name followed by ':'. Each run of
// consecutive lines for the same field at the same indentation is replaced by
// a single "<indent><field>: ..." line. All other lines are copied verbatim.
// Every output line ends with '\n'.
inline std::string remove_repeated_data_strings(std::string const& s) {
  static const std::string fields[] = {
    "float_data:", "int32_data:", "int64_data:", "uint64_data:", "double_data:"
  };
  std::istringstream iss(s);
  std::ostringstream oss;
  std::string run_prefix; // "<indent><field>:" of the run being collapsed, or empty
  for( std::string line; std::getline(iss, line); ) {
    std::string prefix;
    const std::string::size_type indent = line.find_first_not_of(" \t");
    if( indent != std::string::npos ) {
      for( std::string const& field : fields ) {
        if( line.compare(indent, field.size(), field) == 0 ) {
          prefix = line.substr(0, indent + field.size());
          break;
        }
      }
    }
    if( prefix.empty() ) {
      run_prefix.clear();
      oss << line << '\n';
    } else if( prefix != run_prefix ) {
      // First line of a new run: emit the placeholder once, drop the rest.
      run_prefix = prefix;
      oss << prefix << " ...\n";
    }
  }
  return oss.str();
}

} // anonymous namespace

// Renders `message` in protobuf text format, then shortens it for human
// consumption: long raw_data payloads and runs of repeated numeric data fields
// are replaced by "...".
inline std::string pretty_print_onnx_to_string(::google::protobuf::Message const& message) {
  std::string text;
  // NOTE(review): the bool result of PrintToString is not checked.
  ::google::protobuf::TextFormat::PrintToString(message, &text);
  remove_raw_data_strings(text);
  return remove_repeated_data_strings(text);
}

inline std::ostream& operator<<(std::ostream& stream, ::ONNX_NAMESPACE::ModelProto const& message) {
  // Streams the abbreviated text form of the whole model and returns `stream`
  // so that insertions can be chained.
  return stream << pretty_print_onnx_to_string(message);
}

inline std::ostream& operator<<(std::ostream& stream, ::ONNX_NAMESPACE::NodeProto const& message) {
  // Streams the abbreviated text form of a single graph node and returns
  // `stream` so that insertions can be chained.
  return stream << pretty_print_onnx_to_string(message);
}
